package com.carol.bigdata.task.model.algo

import org.apache.spark.ml.classification.{DecisionTreeClassifier, OneVsRest, OneVsRestModel}
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.tuning.{CrossValidator, CrossValidatorModel, ParamGridBuilder}
import org.apache.spark.ml.{Pipeline, PipelineModel}

/** Decision-tree classifier wrapper implementing the `ModelTrait` hooks for
  * pipeline construction, cross-validated grid search and writing the tuned
  * hyper-parameters back into this instance.
  *
  * Hyper-parameters are held in mutable fields so that `init` and the
  * `updateTuneParams*` methods can overwrite them; `buildPipeline` then pushes
  * the current field values onto the shared `tree` estimator.
  */
class DT extends ModelTrait {

    // Default decision-tree hyper-parameters (overwritten by `init` and `updateTuneParams*`).
    var maxDepth: Int = 10
    var maxBins: Int = 32
    var minInfoGain: Double = 0.0

    // Default column names used when `tree` is first constructed.
    val labelCol: String = "label"
    val featuresCol: String = "features"
    val rawPredictionCol: String = "rawPrediction"
    val predictionCol: String = "prediction"

    // Single shared estimator: `buildPipeline` reconfigures and reuses it, and
    // `buildValidator` / `updateTuneParams` use its Param objects as grid keys.
    val tree: DecisionTreeClassifier = new DecisionTreeClassifier()
        .setMaxDepth(maxDepth)
        .setMaxBins(maxBins)
        .setMinInfoGain(minInfoGain)
        .setFeaturesCol(featuresCol)
        .setLabelCol(labelCol)
        .setRawPredictionCol(rawPredictionCol)
        .setPredictionCol(predictionCol)

    // Candidate values explored by the cross-validation parameter grid.
    var tuneMaxDepth: Array[Int] = Array(5, 8)
    var tuneMaxBins: Array[Int] = Array(4, 16)
    var tuneMinInfoGain: Array[Double] = Array(0.1, 0.2)


    /** Overrides the default hyper-parameters and tuning grids from `params`.
      *
      * Recognised keys: `maxDepth`, `maxBins` (Int), `minInfoGain` (Double),
      * `tuneMaxDepth`, `tuneMaxBins` (Array[Int]) and `tuneMinInfoGain`
      * (Array[Double]). Keys that are absent leave the current value untouched.
      *
      * @param params configuration map; each present value must already have
      *               exactly the type listed above
      * @throws ClassCastException if a present value has a different runtime type
      */
    override def init(params: Map[String, Any]): Unit = {
        // The map is untyped, so each value is cast back to the field's type;
        // the current field value serves as the fallback when the key is missing.
        maxDepth = params.getOrElse("maxDepth", maxDepth).asInstanceOf[Int]
        maxBins = params.getOrElse("maxBins", maxBins).asInstanceOf[Int]
        minInfoGain = params.getOrElse("minInfoGain", minInfoGain).asInstanceOf[Double]
        tuneMaxDepth = params.getOrElse("tuneMaxDepth", tuneMaxDepth).asInstanceOf[Array[Int]]
        tuneMaxBins = params.getOrElse("tuneMaxBins", tuneMaxBins).asInstanceOf[Array[Int]]
        tuneMinInfoGain = params.getOrElse("tuneMinInfoGain", tuneMinInfoGain).asInstanceOf[Array[Double]]
    }


    /** Builds a single-stage pipeline around the decision tree.
      *
      * The shared `tree` estimator is reconfigured in place with the current
      * hyper-parameter fields and the given column names, and its parameters
      * are printed to stdout.
      *
      * @param featuresCol      name of the features column
      * @param labelCol         name of the label column
      * @param rawPredictionCol name of the raw-prediction output column
      * @param predictionCol    name of the prediction output column
      * @param objective        exactly `"binary"` uses the tree directly as the
      *                         stage; any other value uses the OneVsRest wrapper
      * @param numClass         not used by this implementation
      * @return a pipeline containing either the tree or its OneVsRest wrapper
      */
    override def buildPipeline(featuresCol: String = "features",
                               labelCol: String = "label",
                               rawPredictionCol: String = "rawPrediction",
                               predictionCol: String = "prediction",
                               objective: String = "binary",
                               numClass: Int = 2): Pipeline = {
        // Apply the (possibly tuned) parameters; the setters mutate and return `tree` itself.
        val model: DecisionTreeClassifier = tree
            .setMaxDepth(maxDepth)
            .setMaxBins(maxBins)
            .setMinInfoGain(minInfoGain)
            .setFeaturesCol(featuresCol)
            .setLabelCol(labelCol)
            .setRawPredictionCol(rawPredictionCol)
            .setPredictionCol(predictionCol)
        println(model.extractParamMap)
        println(model.explainParams)

        // `buildOvrTree` is defined outside this class (presumably in ModelTrait — confirm);
        // it is invoked unconditionally, as before, even when the binary branch is taken.
        val ovrTree: OneVsRest = buildOvrTree(model,
            featuresCol, labelCol, rawPredictionCol, predictionCol, objective)

        // Assemble the pipeline with exactly one stage.
        val pipeline: Pipeline = new Pipeline()
            .setStages(if (objective == "binary") Array(model) else Array(ovrTree))

        pipeline
    }


    /** Builds a cross-validator that grid-searches `maxDepth`, `maxBins` and
      * `minInfoGain` over the `tune*` candidate arrays.
      *
      * All arguments are forwarded unchanged to `buildValidatorFromGrid`,
      * which is defined outside this class.
      *
      * @param pipeline    pipeline to tune
      * @param seed        forwarded to `buildValidatorFromGrid`
      * @param numFolds    forwarded to `buildValidatorFromGrid`
      * @param parallelNum forwarded to `buildValidatorFromGrid`
      * @param objective   forwarded to `buildValidatorFromGrid`
      * @return the configured cross-validator
      */
    override def buildValidator(pipeline: Pipeline,
                                seed: Int = 1,
                                numFolds: Int = 2,
                                parallelNum: Int = 2,
                                objective: String = "binary"): CrossValidator = {
        val gridBuilder = new ParamGridBuilder()
            .addGrid(tree.maxDepth, tuneMaxDepth)
            .addGrid(tree.maxBins, tuneMaxBins)
            .addGrid(tree.minInfoGain, tuneMinInfoGain)

        val cv = buildValidatorFromGrid(pipeline, gridBuilder, seed, numFolds, parallelNum, objective)

        cv
    }


    /** Copies tuned values out of `bestParamsMap` into the hyper-parameter
      * fields, keeping the current value for any parameter the map lacks.
      *
      * @param bestParamsMap parameter map keyed by the shared `tree`'s params
      */
    override def updateTuneParams(bestParamsMap: ParamMap): Unit = {
        maxDepth = bestParamsMap.getOrElse(tree.maxDepth, maxDepth)
        maxBins = bestParamsMap.getOrElse(tree.maxBins, maxBins)
        minInfoGain = bestParamsMap.getOrElse(tree.minInfoGain, minInfoGain)
    }

    /** Selects the best-scoring parameter map of a fitted cross-validator and
      * copies its values into the hyper-parameter fields.
      *
      * @param crossModel fitted cross-validator whose best model is a `PipelineModel`
      * @param maxBy      `true` picks the map with the highest average metric,
      *                   `false` the lowest
      * @param objective  kept for interface compatibility; the fitted stage's
      *                   actual type now decides how the classifier uid is found
      */
    override def updateTuneParamsFromCV(crossModel: CrossValidatorModel,
                                        maxBy: Boolean = true,
                                        objective: String = "binary"): Unit = {
        val scoredParamMaps = crossModel.getEstimatorParamMaps.zip(crossModel.avgMetrics)
        val bestParamsMap =
            if (maxBy) scoredParamMaps.maxBy(_._2)._1 else scoredParamMaps.minBy(_._2)._1

        // Resolve the uid from what the pipeline really contains. Deciding by a
        // substring test on `objective` disagreed with the exact `== "binary"`
        // test in `buildPipeline` (e.g. "Binary"), which either yielded the
        // OneVsRest wrapper's uid (tuned values silently ignored) or a failing cast.
        val lastStage = crossModel.bestModel.asInstanceOf[PipelineModel].stages.last
        val bestModelId = lastStage match {
            case ovr: OneVsRestModel => ovr.getClassifier.uid
            case stage => stage.uid
        }

        // Rebuild lookup keys owned by that uid; ParamMap lookups match on parent uid and name.
        val dthParam = new Param[Int](bestModelId, name = "maxDepth", doc = "maxDepth")
        val binParam = new Param[Int](bestModelId, name = "maxBins", doc = "maxBins")
        val gainParam = new Param[Double](bestModelId, name = "minInfoGain", doc = "minInfoGain")
        maxDepth = bestParamsMap.getOrElse(dthParam, maxDepth)
        maxBins = bestParamsMap.getOrElse(binParam, maxBins)
        minInfoGain = bestParamsMap.getOrElse(gainParam, minInfoGain)
    }


}
